original source: golbin's github
위 원본 코드를 약간의 주석을 추가하여 jupyter-notebook으로 옮겼습니다.
In [17]:
# GAN 모델을 이용해 단순히 랜덤한 숫자를 생성하는 아닌,
# 원하는 손글씨 숫자를 생성하는 모델을 만들어봅니다.
import tensorflow as tf
import matplotlib.pylab as plt
import seaborn
import numpy as np
In [2]:
# tensorflow에서 제공하는 mnist 데이터 로딩
from tensorflow.examples.tutorials.mnist import input_data
# one_hot=True -> label이 one-hot encoding됨
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)
In [3]:
n_hidden = 256 # 히든 레이어의 뉴런 갯수
n_input = 28 * 28 # 입력 크기: 28 * 28 크기의 이미지를 1차원 벡터로 입력
n_noise = 128 # 생성기의 입력으로 들어갈 노이즈의 크기
n_class = 10 # 레이블 갯수 0부터 9까지 총 10개
In [4]:
X = tf.placeholder(dtype=tf.float32, shape=[None, n_input]) # 이미지 입력
Y = tf.placeholder(dtype=tf.float32, shape=[None, n_class]) # 0~9까지 이미지에 해당하는 레이블 입력
Z = tf.placeholder(dtype=tf.float32, shape=[None, n_noise]) # 노이즈 입력
In [5]:
def generator(noise, labels):
with tf.variable_scope('generator'):
# noise 값에 labels 정보를 추가합니다.
inputs = tf.concat([noise, labels], 1)
# TensorFlow 에서 제공하는 유틸리티 함수를 이용해 신경망을 매우 간단하게 구성할 수 있습니다.
hidden = tf.layers.dense(inputs, n_hidden, activation=tf.nn.relu)
output = tf.layers.dense(hidden, n_input, activation=tf.nn.sigmoid)
return output
In [6]:
def discriminator(inputs, labels, reuse=None):
with tf.variable_scope('discriminator') as scope:
# 노이즈에서 생성한 이미지와 실제 이미지를 판별하는 모델의 변수를 동일하게 하기 위해,
# 이전에 사용되었던 변수를 재사용하도록 합니다.
if reuse:
scope.reuse_variables()
inputs = tf.concat([inputs, labels], 1)
hidden = tf.layers.dense(inputs, n_hidden, activation=tf.nn.relu)
output = tf.layers.dense(hidden, 1, activation=None)
return output
In [7]:
G = generator(Z, Y)
D_real = discriminator(X, Y)
D_gene = discriminator(G, Y, True)
In [8]:
# 손실함수는 다음을 참고하여 GAN 논문에 나온 방식과는 약간 다르게 작성하였습니다.
# http://bamos.github.io/2016/08/09/deep-completion/
# 진짜 이미지를 판별하는 D_real 값은 1에 가깝도록,
# 가짜 이미지를 판별하는 D_gene 값은 0에 가깝도록 하는 손실 함수입니다.
loss_D_real = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_real, labels=tf.ones_like(D_real)))
loss_D_gene = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_gene, labels=tf.zeros_like(D_gene)))
# loss_D_real 과 loss_D_gene 을 더한 뒤 이 값을 최소화 하도록 최적화합니다.
loss_D = loss_D_real + loss_D_gene
# 가짜 이미지를 진짜에 가깝게 만들도록 생성망을 학습시키기 위해, D_gene 을 최대한 1에 가깝도록 만드는 손실함수입니다.
loss_G = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_gene, labels=tf.ones_like(D_gene)))
In [9]:
# TensorFlow 에서 제공하는 유틸리티 함수를 이용해
# discriminator 와 generator scope 에서 사용된 변수들을 쉽게 가져올 수 있습니다.
vars_D = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
vars_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
# 학습시에 var_list에 명시한 변수들만 조정하게됨
# AdamOptimizer에 parameter로 learning_rate를 주지 않으면 기본값 0.001로 학습
train_D = tf.train.AdamOptimizer().minimize(loss_D, var_list=vars_D)
train_G = tf.train.AdamOptimizer().minimize(loss_G, var_list=vars_G)
In [10]:
# noise generate function
def get_noise(batch_size, n_noise):
return np.random.uniform(-1., 1., size=[batch_size, n_noise])
In [13]:
total_epoch = 100 # 학습할 epoch 횟수
batch_size = 100 # 배치 크기
In [21]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
total_batch = int(mnist.train.num_examples/batch_size)
loss_val_D, loss_val_G = 0, 0
for epoch in range(total_epoch):
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
noise = get_noise(batch_size, n_noise)
_, loss_val_D = sess.run([train_D, loss_D],
feed_dict={X: batch_xs, Y: batch_ys, Z: noise})
_, loss_val_G = sess.run([train_G, loss_G],
feed_dict={Y: batch_ys, Z: noise})
print('Epoch:', '%04d' % epoch,
'D loss: {:.4}'.format(loss_val_D),
'G loss: {:.4}'.format(loss_val_G))
#########
# 학습이 되어가는 모습을 보기 위해 주기적으로 레이블에 따른 이미지를 생성하여 저장
######
if epoch == 0 or (epoch + 1) % 10 == 0:
sample_size = 10
noise = get_noise(sample_size, n_noise)
samples = sess.run(G,
feed_dict={Y: mnist.test.labels[:sample_size],
Z: noise})
fig, ax = plt.subplots(2, sample_size, figsize=(sample_size, 2))
for i in range(sample_size):
ax[0][i].set_axis_off()
ax[1][i].set_axis_off()
ax[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)), cmap='gray')
ax[1][i].imshow(np.reshape(samples[i], (28, 28)), cmap='gray')
plt.show()
In [ ]: